
Adding FSDP Support to Training Library #213

Merged · 47 commits · merged into main on Sep 26, 2024

Conversation

aldopareja (Member) commented Sep 18, 2024

Adds support for FSDP and FSDP with CPU offloading.

- Introduces `accelerate` as a distributed backend abstraction (for FSDP/DeepSpeed).
- Also fixes the Mistral template and cleans up data processing.

-Mustafa
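
For context, a minimal sketch of how an FSDP-backed Accelerator can be built through accelerate's `FullyShardedDataParallelPlugin`. This is not the PR's exact code; the helper name, dtype choices, and sharding strategy below are illustrative assumptions.

```python
# Hypothetical sketch, not the PR's implementation: building an Accelerator
# with an FSDP plugin, including optional CPU offloading of parameters.
import torch
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
)


def make_fsdp_accelerator(cpu_offload_params: bool = False) -> Accelerator:
    """Illustrative helper; dtype and sharding choices are assumptions."""
    fsdp_plugin = FullyShardedDataParallelPlugin(
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        mixed_precision_policy=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        cpu_offload=CPUOffload(offload_params=cpu_offload_params),
    )
    return Accelerator(fsdp_plugin=fsdp_plugin)
```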

@mergify mergify bot added the ci-failure label Sep 18, 2024
@mergify mergify bot added the ci-failure and dependencies labels and removed the ci-failure label Sep 18, 2024
@Maxusmusti Maxusmusti changed the title Ap/accelerate fsdp tmp2 Adding FSDP Support to Training Library Sep 24, 2024
@mergify mergify bot added the ci-failure and CI/CD labels and removed the ci-failure label Sep 24, 2024
This was referenced Sep 24, 2024
Resolved review threads:
- src/instructlab/training/config.py
- src/instructlab/training/main_ds.py (several threads, some outdated)
- src/instructlab/training/utils.py (several threads, some outdated)
mergify bot (Contributor) commented Sep 25, 2024

This pull request has merge conflicts that must be resolved before it can be merged. @aldopareja please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the one-approval label Sep 25, 2024
Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
@mergify mergify bot added the documentation and ci-failure labels Sep 25, 2024
…ining_backend to TrainingArgs.distributed_backend and DistributedTrainingBackend to DistributedBackend

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com>
@mergify mergify bot added ci-failure and removed ci-failure labels Sep 25, 2024
Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com>
@mergify mergify bot added ci-failure and removed ci-failure labels Sep 25, 2024
@@ -157,6 +181,12 @@ class TrainingArgs(BaseModel):
cpu_offload_optimizer_pin_memory=False,
)
)
fsdp_options: FSDPOptions = Field(
Contributor:

Does this need to be a factory? I think it can just be an assignment.

Contributor:

I'm following the current convention set by DeepSpeedOptions in the file, so IMO if we want to change this, we should make a follow-up PR that updates both of them.
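
To illustrate the two options being weighed, a hedged sketch follows; the `FSDPOptions` field shown is made up for the example, and the factory form mirrors the existing `DeepSpeedOptions` convention in this file.

```python
# Sketch only; the real FSDPOptions fields live in config.py and may differ.
from pydantic import BaseModel, Field


class FSDPOptions(BaseModel):
    cpu_offload_params: bool = False  # illustrative field


class TrainingArgs(BaseModel):
    # Factory style, matching the DeepSpeedOptions convention in this file:
    fsdp_options: FSDPOptions = Field(
        default_factory=lambda: FSDPOptions(cpu_offload_params=False)
    )
    # Plain-assignment alternative raised in the review; pydantic copies
    # field defaults per instance, so either form works here:
    # fsdp_options: FSDPOptions = FSDPOptions(cpu_offload_params=False)
```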

reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
Contributor:

Do we want to expose this ever? This adds a bit of memory overhead for some performance; I think customarily it's probably left at the default.

Contributor:

This is a good point. I think it's fine for now, but I will open an issue to track this, as I'm not sure how much of a performance hit compared to memory gain this option will be for us. It might be a nice bonus trick to avoid offloading in some configurations if the performance isn't horrendous.

Contributor:

Tracked in #228
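
As a sketch of what the follow-up tracked in #228 might look like, a config string could be mapped onto torch's `BackwardPrefetch` enum. The option names and helper below are hypothetical; the PR itself hard-codes `BACKWARD_PRE`.

```python
# Hypothetical helper for exposing backward prefetching as a config option.
from typing import Optional

from torch.distributed.fsdp import BackwardPrefetch


def resolve_backward_prefetch(name: Optional[str]) -> Optional[BackwardPrefetch]:
    """BACKWARD_PRE prefetches the next parameter shards before the current
    gradient computation (more overlap, more memory); BACKWARD_POST prefetches
    afterwards (less memory); None disables prefetching."""
    mapping = {
        "pre": BackwardPrefetch.BACKWARD_PRE,
        "post": BackwardPrefetch.BACKWARD_POST,
        None: None,
    }
    return mapping[name.lower() if name else None]
```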

}
return ds_config
def setup_optimizer(args, model):
if args.distributed_training_framework == "fsdp":
Contributor:

The typical way to do this is via this pattern:

Suggested change:
- if args.distributed_training_framework == "fsdp":
+ if DistributedBackend(args.distributed_training_framework) == DistributedBackend.FSDP:

This collects "magic strings" like "fsdp" into the Enum object.

Contributor:

Note: it actually has to be DistributedBackend.FSDP.value, since by this point the args have gone through the main_ds argparse post-torchrun and args.distributed_training_framework is just a string.

Contributor:

Fixed in latest commit.
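
For readers following along, a hedged sketch of the enum comparison being suggested. `DistributedBackend` and the `.value` subtlety come from the thread above; the `DEEPSPEED` member and the helper function are assumptions for illustration.

```python
# Sketch of the pattern discussed above; the real DistributedBackend enum
# lives in the training library's config module.
from enum import Enum


class DistributedBackend(Enum):
    FSDP = "fsdp"
    DEEPSPEED = "deepspeed"  # assumed member name


def uses_fsdp(distributed_training_framework: str) -> bool:
    # Post-torchrun, argparse hands us a plain string, so compare against the
    # member's .value (round-tripping through the enum also validates unknown
    # strings, at the cost of raising ValueError on typos).
    return distributed_training_framework == DistributedBackend.FSDP.value
```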

model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
)
accelerator = setup_accelerator(args, model, grad_accum)
if args.distributed_training_framework == "fsdp":
Contributor:

Same enum trick here.

Contributor:

Note: it actually has to be DistributedBackend.FSDP.value, since by this point the args have gone through the main_ds argparse post-torchrun and args.distributed_training_framework is just a string.

Contributor:

Fixed in latest commit.

),
lr_scheduler=lr_scheduler,
dist_init_required=True,
model, optimizer, _, lr_scheduler = accelerator.prepare(
Contributor:

I see here that we're "double preparing" the model. Is that okay? Is Accelerate smart enough to handle this?

Contributor:

Yes, I have verified that it is. Originally I had some conditionals to avoid it, but accelerate was one step ahead.
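
For reference, the intended single-call shape of `prepare` looks roughly like the sketch below (variable names are illustrative); per the exchange above, accelerate also tolerates the model having already passed through an earlier prepare step.

```python
# Illustrative sketch: accelerator.prepare wraps the model for FSDP/DeepSpeed
# and returns shard- and device-aware versions of the other objects.
model, optimizer, train_loader, lr_scheduler = accelerator.prepare(
    model,
    optimizer,
    train_loader,
    lr_scheduler,
)
```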

global_grad_norm = accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
Contributor:

I haven't seen this here conventionally, only at the top of the training loop. I guess it can be either place. I also see that this is where they put it in the docs.

Contributor:

If it ain't broke 🤷🏻‍♂️

Contributor:

++
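
A minimal sketch of the step ordering in question, assuming an HF-style model that returns a `.loss` attribute; zeroing gradients at the end of the step is equivalent to zeroing at the top of the next iteration, as long as it happens between `optimizer.step()` and the next backward pass.

```python
# Sketch only; batch handling and loss computation are simplified.
for batch in train_loader:
    loss = model(**batch).loss
    accelerator.backward(loss)
    global_grad_norm = accelerator.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()  # end-of-step placement, as in this PR
```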

JamesKunstle (Contributor) left a comment:

IMO nothing that I noticed is blocking an approval. The only thing that I really want is for this PR to be rebased as a single commit so the history is a bit neater. Once that's done I'll approve!

Signed-off-by: Mustafa Eyceoz <meyceoz@redhat.com>
JamesKunstle (Contributor) left a comment:

lgtm!

@mergify mergify bot removed the one-approval label Sep 26, 2024
@Maxusmusti Maxusmusti merged commit 7b7fa12 into main Sep 26, 2024
14 checks passed
Labels: CI/CD, dependencies, documentation, hold
5 participants